#!/usr/bin/env python3
# -*- coding:utf-8 -*-

import socket
import threading
import time
import random
import struct
import logging
from typing import List
import cv2
import numpy as np
from aigc_task_manager_v2 import AIGCTaskManager, AIGCSocket, AIGCTask, AIGCTaskDef, g_sam_embeddings, g_cached_images


def random_vcode(k=4):
    """Return a random code of ``k`` characters (lowercase letters and digits).

    Characters are drawn independently, with replacement, via
    ``random.choices``; this is not a cryptographically secure source.

    Args:
        k: Number of characters in the returned code.

    Returns:
        A ``str`` of length ``k``.
    """
    # The symbol order (z..a, then 0..9) decides which character a given
    # random draw maps to, so it matters for seeded reproducibility.
    alphabet = 'zyxwvutsrqponmlkjihgfedcba0123456789'
    picks = random.choices(alphabet, k=k)
    return ''.join(picks)


class SAMPipeline(threading.Thread):
    """Per-client receiver thread: parses request frames and queues tasks.

    Request frame layout as read by this class (multi-byte fields use the
    ``struct`` module's native byte order)::

        bytes 0-1   magic FA FC
        byte  2     message type (see below)
        byte  3     not read by this class
        bytes 4-7   request id   (uint32)
        bytes 8-11  payload size (uint32)
        bytes 12-   payload

    Message types:
        0x00  time echo, answered directly on the socket
        0x01  encoded image          -> ``AIGCTaskDef.SAM_Backbone``
        0x02  single point (x, y)    -> ``AIGCTaskDef.SAM_Point``
        0x03  box + labelled points  -> ``AIGCTaskDef.SAM_Box_Points_Labels``
        0x04  encoded image          -> ``AIGCTaskDef.CacheImage``
        0x05  UTF-8 text             -> ``AIGCTaskDef.GroundingDINO``

    Frames with any other type byte are consumed and ignored.
    """

    _MAGIC = b'\xFA\xFC'
    _HEADER_LEN = 12
    _RECV_SIZE = 1024 * 1024  # 1 MiB per recv() call

    def __init__(self, client_key, client_socket, _server):
        """Bind the thread to one client connection.

        Args:
            client_key: Identifier of the client; attached to every task and
                passed to ``_server.quit`` when the connection ends.
            client_socket: Connected socket to read requests from.
            _server: Owner object; must expose ``task_manager.add_task(task)``
                and ``quit(client_key)``.
        """
        threading.Thread.__init__(self)
        self.client_key = client_key
        self.client_socket = client_socket
        self._server = _server
        self.vcode = random_vcode()
        self.running = True

    def run(self):
        """Read frames until the peer disconnects or an error occurs.

        On exit the client socket is closed and ``_server.quit(client_key)``
        is called. Any exception raised while receiving or handling a frame
        is printed and ends the session.
        """
        pending = b''  # received bytes that have not been consumed yet
        while self.running:
            try:
                if len(pending) >= self._HEADER_LEN:
                    data = pending
                else:
                    chunk = self.client_socket.recv(self._RECV_SIZE)
                    if not chunk:
                        # Peer closed the connection. Leftover bytes in
                        # `pending` can never grow into a full frame, so stop
                        # here instead of looping on them forever.
                        self.running = False
                        break
                    data = pending + chunk

                if data[:2] != self._MAGIC:
                    pending = self._resync(data)
                    continue
                if len(data) < self._HEADER_LEN:
                    # Header is still incomplete; wait for more bytes.
                    pending = data
                    continue

                req_id, n_payloads = struct.unpack_from('II', data, 4)
                frame_len = self._HEADER_LEN + n_payloads

                restarted = False
                while len(data) < frame_len:
                    more = self.client_socket.recv(self._RECV_SIZE)
                    if not more:  # disconnected mid-frame
                        self.running = False
                        break
                    if more[:2] == self._MAGIC:
                        # A chunk starting with the magic bytes is treated as
                        # the start of a new frame; the partial one is dropped.
                        # NOTE(review): a payload chunk that happens to begin
                        # with FA FC is misread as a new frame here.
                        pending = more
                        restarted = True
                        break
                    data += more

                if restarted or not self.running:
                    continue

                self._handle_frame(req_id, data[:frame_len], n_payloads)
                pending = data[frame_len:]
            except Exception as e:
                print(e)
                self.running = False

        self.client_socket.close()
        self._server.quit(self.client_key)

    @staticmethod
    def _resync(data):
        """Return ``data`` with leading bytes dropped up to the next magic marker.

        ``data`` is known not to start with the magic bytes. If no marker is
        found, a trailing lone ``FA`` is kept because the matching ``FC`` may
        arrive with the next read; everything else is discarded.
        """
        idx = data.find(b'\xFA\xFC', 1)
        if idx >= 0:
            return data[idx:]
        return data[-1:] if data.endswith(b'\xFA') else b''

    def _new_task(self, task_def, req_id):
        """Create a task for this client with its request id already set."""
        task = AIGCTask(self.client_key, task_def)
        task.set_req_id(req_id)
        return task

    def _queue_image_task(self, task_def, req_id, payload):
        """Decode an encoded image payload to RGB and queue it as ``task_def``."""
        encoded = np.asarray(bytearray(payload), dtype="uint8")
        image = cv2.imdecode(encoded, cv2.IMREAD_COLOR)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        task = self._new_task(task_def, req_id)
        task.set_image(image)
        self._server.task_manager.add_task(task)

    def _handle_frame(self, req_id, frame, n_payloads):
        """Dispatch one complete frame (header + payload) by its type byte.

        Args:
            req_id: Request id parsed from the header.
            frame: Exactly ``12 + n_payloads`` bytes starting at the magic.
            n_payloads: Payload size in bytes, from the header.
        """
        msg_type = frame[2]
        payload = frame[12: 12 + n_payloads]

        if msg_type == 0x00:  # time echo
            time_ms = struct.unpack_from('d', frame, 12)[0]
            print("time in ms = {}, id = {}".format(time_ms, req_id))
            return_msg = b''.join([
                b'\xFB\xFD',
                b'\x00\xFF',
                frame[4: 8],
                struct.pack('I', 8),
                struct.pack('d', time_ms),
            ])
            # sendall: plain send() may transmit only part of the buffer.
            self.client_socket.sendall(return_msg)
        elif msg_type == 0x01:  # image for the SAM backbone
            self._queue_image_task(AIGCTaskDef.SAM_Backbone, req_id, payload)
        elif msg_type == 0x02:  # single point coordinate
            pt_x, pt_y = struct.unpack_from('ff', frame, 12)
            task = self._new_task(AIGCTaskDef.SAM_Point, req_id)
            task.set_point(np.array([pt_x, pt_y]))
            self._server.task_manager.add_task(task)
        elif msg_type == 0x03:  # bounding box + point sequence
            self._queue_box_points_task(req_id, frame)
        elif msg_type == 0x04:  # image to cache
            self._queue_image_task(AIGCTaskDef.CacheImage, req_id, payload)
        elif msg_type == 0x05:  # GroundingDINO text prompt
            text = str(payload, encoding='utf-8')
            task = self._new_task(AIGCTaskDef.GroundingDINO, req_id)
            task.set_text(text)
            self._server.task_manager.add_task(task)

    def _queue_box_points_task(self, req_id, frame):
        """Parse a type-0x03 frame and queue a box/points/labels task.

        Payload layout: ``uint32 n_bboxes``, ``uint32 n_pts``, then
        ``n_bboxes`` boxes of four float32 each, then ``n_pts`` records of
        two float32 plus one label byte (0 -> label 0, anything else -> 1).
        Only the first box is read, even when ``n_bboxes`` is larger than 1.
        A frame too short for the declared counts is silently ignored.
        """
        n_bboxes, n_pts = struct.unpack_from('II', frame, 12)
        points_off = 20 + n_bboxes * 16
        if len(frame) < points_off + n_pts * 9:
            return

        input_box = np.array([])
        if n_bboxes > 0:
            input_box = np.array(struct.unpack_from('ffff', frame, 20))

        coords = []
        labels = []
        for j in range(n_pts):
            off = points_off + j * 9
            coords.extend(struct.unpack_from('ff', frame, off))
            labels.append(0 if frame[off + 8] == 0x00 else 1)
        input_point = np.array(coords).reshape(n_pts, 2)
        input_label = np.array(labels)

        task = self._new_task(AIGCTaskDef.SAM_Box_Points_Labels, req_id)
        task.set_box_points_labels(input_box, input_point, input_label)
        self._server.task_manager.add_task(task)


class SAMTcpServer(threading.Thread, AIGCSocket):
    """TCP front end: accepts clients, starts one SAMPipeline per client and
    sends finished task results back to the originating client.

    Reply frame layout written by :meth:`response_info` (multi-byte fields
    use the ``struct`` module's native byte order)::

        bytes 0-1   magic FB FD
        byte  2     reply type (0x01 .. 0x05, matching the task kind)
        byte  3     FF
        bytes 4-7   request id   (uint32)
        bytes 8-11  payload size (uint32)
        bytes 12-   payload
    """

    def __init__(self, port=9000):
        """Bind and listen on ``port`` (all interfaces) and start the task manager.

        Args:
            port: TCP port to listen on.
        """
        threading.Thread.__init__(self)
        socket_server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        socket_server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        socket_server.bind(('', port))
        socket_server.listen(5)
        print('Start listening on port {} ...'.format(port))
        self.socket_server = socket_server
        self.listening = True
        self.connected_clients = dict()  # client_key -> client socket
        self.task_manager = AIGCTaskManager(self)
        self.task_manager.start()

    def quit(self, client_key=None):
        """Shut down everything, or forget a single client.

        Args:
            client_key: ``None`` closes every client socket and clears the
                ``listening`` flag. Otherwise the given client is removed
                from ``connected_clients`` together with its entries in
                ``g_sam_embeddings`` and ``g_cached_images``; an unknown key
                is ignored.
        """
        if client_key is None:
            # Iterate over a snapshot: pipeline threads call quit(key) when
            # their socket is closed, which removes entries from the dict.
            for client in list(self.connected_clients.values()):
                client.close()
            # NOTE(review): the listening socket is left open, so a run()
            # blocked in accept() only sees this flag after the next client
            # connects.
            self.listening = False
        else:
            self.connected_clients.pop(client_key, None)
            if client_key in g_sam_embeddings.keys():
                del g_sam_embeddings[client_key]
            if client_key in g_cached_images.keys():
                del g_cached_images[client_key]
        print("Now clients remain: {}".format(len(self.connected_clients)))

    def run(self):
        """Accept clients and start a SAMPipeline thread for each one."""
        while self.listening:
            client_socket, _client_address = self.socket_server.accept()
            client_key = random_vcode()
            # A duplicate key would overwrite another client's socket entry.
            while client_key in self.connected_clients:
                client_key = random_vcode()
            print('Got client: [{}]'.format(client_key))

            self.connected_clients[client_key] = client_socket
            pipeline = SAMPipeline(client_key, client_socket, self)
            pipeline.start()

    @staticmethod
    def _pack_grounding_results(results):
        """Serialize GroundingDINO results into a reply payload.

        Layout: ``uint32`` byte length of the box section, one
        ``5 x float32`` record per box (four box values + logit), ``uint32``
        byte length of the phrase section, then every phrase UTF-8 encoded
        and followed by ``b','``.

        Args:
            results: Mapping with ``'boxes'`` (indexable as ``[i, j]``),
                ``'logits'`` and ``'phrases'`` (iterable of ``str``).
        """
        boxes = results['boxes']
        logits = results['logits']
        bl_msg = b''.join(
            struct.pack(
                'fffff',
                boxes[i, 0],
                boxes[i, 1],
                boxes[i, 2],
                boxes[i, 3],
                logits[i],
            )
            for i in range(len(boxes))
        )
        ph_msg = b''.join(
            bytes(phrase, encoding='utf-8') + b','
            for phrase in results['phrases']
        )
        return b''.join([
            struct.pack('I', len(bl_msg)),
            bl_msg,
            struct.pack('I', len(ph_msg)),
            ph_msg,
        ])

    def response_info(self, task: AIGCTask):
        """Send the result of ``task`` to the client identified by ``task.uid``.

        ``SAM_Backbone`` and ``CacheImage`` replies carry an empty payload.
        ``SAM_Point`` and ``SAM_Box_Points_Labels`` replies carry
        ``task.results`` unchanged (it is concatenated to ``bytes``, so it
        must be bytes-like). ``GroundingDINO`` replies carry the packed
        boxes/logits/phrases. Tasks of any other kind produce no reply.

        Raises:
            KeyError: If ``task.uid`` is not a connected client.
        """
        kind = task.task
        if kind == AIGCTaskDef.SAM_Backbone:
            type_byte, payload = b'\x01', b''
        elif kind == AIGCTaskDef.SAM_Point:
            type_byte, payload = b'\x02', task.results
        elif kind == AIGCTaskDef.SAM_Box_Points_Labels:
            type_byte, payload = b'\x03', task.results
        elif kind == AIGCTaskDef.CacheImage:
            type_byte, payload = b'\x04', b''
        elif kind == AIGCTaskDef.GroundingDINO:
            type_byte = b'\x05'
            payload = self._pack_grounding_results(task.results)
        else:
            return

        # The size field is taken from the encoded payload itself, so it is
        # a byte count even when phrases contain non-ASCII characters.
        header = (b'\xFB\xFD' + type_byte + b'\xFF'
                  + struct.pack('I', task.req_id)
                  + struct.pack('I', len(payload)))
        # sendall: plain send() may transmit only part of a large reply.
        self.connected_clients[task.uid].sendall(header + payload)

    def response_error(self, task: AIGCTask):
        """Error-reply hook; currently does nothing."""
        pass


if __name__ == '__main__':
    # Script entry point: create the server listening on port 9093 and start
    # its accept loop in its own thread.
    server = SAMTcpServer(port=9093)
    server.start()
